#!/usr/bin/env python3
"""
GPT-2 style language model implemented with PyTorch.

reference: https://github.com/karpathy/nanoGPT
"""

import math
from collections import OrderedDict

import torch
import torch.nn as nn

# from modern_hopfield_attention.layers import SelfAttention


class _SelfAttention(nn.Module):
    """Multi-head causal self-attention.

    The input is projected to queries, keys and values by a single fused
    linear layer, split into ``num_heads`` heads, combined with scaled
    dot-product attention under a lower-triangular (causal) mask, and finally
    passed through an output projection.

    Args:
        sequence_length: Side length of the causal mask, i.e. the longest
            sequence ``forward`` accepts.
        embedding_dim: Size of the last dimension of the input; must be
            divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qk_norm: If true, apply ``norm_layer`` to the per-head queries and
            keys before computing attention scores.
        norm_layer: Callable building the normalization module; it is called
            with the per-head dimension.
        dropout_prob: Dropout probability for both the attention weights and
            the projected output.
        bias: Whether the linear layers carry a bias term.

    Raises:
        ValueError: If ``embedding_dim`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        sequence_length: int,
        embedding_dim: int,
        num_heads: int,
        qk_norm: bool = False,
        norm_layer: nn.Module = nn.LayerNorm,
        dropout_prob: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        if num_heads <= 0 or embedding_dim % num_heads != 0:
            raise ValueError(
                f"embedding_dim ({embedding_dim}) must be divisible by "
                f"num_heads ({num_heads})"
            )

        # attributes
        self.num_heads = num_heads
        self.head_dim = embedding_dim // num_heads

        # layers
        self.qkv = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
        # q and k are normalized after the head split, when their last
        # dimension is head_dim -- so the norm must be sized by head_dim,
        # not by the number of heads.
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()

        self.attn_dropout = nn.Dropout(p=dropout_prob)

        self.proj = nn.Linear(embedding_dim, embedding_dim, bias=bias)
        self.proj_dropout = nn.Dropout(p=dropout_prob)

        # buffer: (1, 1, sequence_length, sequence_length) causal mask
        self.register_buffer("attn_mask", self._generate_attn_mask(sequence_length))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply causal self-attention.

        Args:
            x: Tensor of shape ``(B, T, C)`` -- batch size, sequence length,
                embedding dimension.

        Returns:
            Tensor of shape ``(B, T, C)``.

        Raises:
            ValueError: If ``T`` exceeds the size of the causal mask.
        """
        B, T, C = x.shape  # batch size, sequence length, embedding dimensionality
        max_len = self.attn_mask.size(-1)
        if T > max_len:
            raise ValueError(
                f"input sequence length ({T}) exceeds the maximum "
                f"sequence length ({max_len})"
            )

        q, k, v = self.qkv(x).split(C, dim=-1)

        # reshape: (B, T, C) -> (B, nh, T, hs)
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        # optional query/key normalization (identity when qk_norm is off)
        q, k = self.q_norm(q), self.k_norm(k)

        # scale the queries so the logits are divided by sqrt(head_dim)
        q = q * (1 / math.sqrt(self.head_dim))

        # attention logits: (B, nh, T, T)
        attn = q @ k.transpose(-2, -1)

        # causal mask: positions where the mask is 0 (future tokens) get -inf
        # so that softmax assigns them zero weight
        attn = attn.masked_fill(self.attn_mask[:, :, :T, :T] == 0, float("-inf"))
        attn = attn.softmax(-1)
        attn = self.attn_dropout(attn)

        # weighted sum of values, then merge heads: (B, nh, T, hs) -> (B, T, C)
        x = attn @ v
        x = x.transpose(1, 2).contiguous().view(B, T, C)

        x = self.proj(x)
        x = self.proj_dropout(x)

        return x

    def _generate_attn_mask(self, sequence_length: int) -> torch.Tensor:
        """Build the causal mask.

        Returns:
            A ``(1, 1, sequence_length, sequence_length)`` tensor holding the
            lower triangle of a matrix of ones.
        """
        attn_mask = torch.tril(torch.ones(sequence_length, sequence_length)).view(
            1, 1, sequence_length, sequence_length
        )
        # NOTE(review): presumably a workaround for an MPS backend issue; the
        # reason is not visible here -- confirm before removing.
        if str(next(self.parameters()).device).startswith("mps"):
            attn_mask = torch.nan_to_num(attn_mask, nan=0.0)

        return attn_mask


class _MLP(nn.Sequential):
    """Position-wise feed-forward network.

    A sequential stack of four named children: ``fc`` (linear expansion to
    four times the embedding size), ``gelu``, ``proj`` (linear projection back
    to the embedding size) and ``dropout``.

    Args:
        embedding_dim: Size of the input and output features.
        dropout_prob: Dropout probability applied after the projection.
        bias: Whether the two linear layers carry a bias term.
    """

    def __init__(
        self,
        embedding_dim: int,
        dropout_prob: float = 0.0,
        bias: bool = True,
    ) -> None:
        hidden_dim = 4 * embedding_dim

        # Child names are part of the parameter names, so they are kept as-is.
        layers = OrderedDict()
        layers["fc"] = nn.Linear(embedding_dim, hidden_dim, bias=bias)
        layers["gelu"] = nn.GELU()
        layers["proj"] = nn.Linear(hidden_dim, embedding_dim, bias=bias)
        layers["dropout"] = nn.Dropout(dropout_prob)

        super().__init__(layers)


class _Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual.

    Args:
        sequence_length: Forwarded to the attention layer.
        embedding_dim: Feature size of the block's input and output.
        num_heads: Number of attention heads.
        dropout_prob: Dropout probability forwarded to attention and MLP.
        bias: Whether the layer norms and linear layers carry a bias term.
    """

    def __init__(
        self,
        sequence_length: int,
        embedding_dim: int,
        num_heads: int,
        dropout_prob: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        # options shared by both sub-layers
        shared = dict(dropout_prob=dropout_prob, bias=bias)

        # registration order (ln_1, attn, ln_2, mlp) is kept on purpose
        self.ln_1 = nn.LayerNorm(embedding_dim, bias=bias)
        self.attn = _SelfAttention(
            sequence_length=sequence_length,
            embedding_dim=embedding_dim,
            num_heads=num_heads,
            **shared,
        )
        self.ln_2 = nn.LayerNorm(embedding_dim, bias=bias)
        self.mlp = _MLP(embedding_dim=embedding_dim, **shared)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` after the attention and MLP residual branches."""
        attended = x + self.attn(self.ln_1(x))
        return attended + self.mlp(self.ln_2(attended))


class GPT2_ori(nn.Module):

    def __init__(
        self,
        vocab_size: int,
        sequence_length: int,
        embedding_dim: int,
        num_heads: int,
        depth: int,
        dropout_prob: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()

        self.transformer = nn.ModuleDict(
            OrderedDict(
                token_embedding=nn.Embedding(vocab_size, embedding_dim),
                position_embedding=nn.Embedding(sequence_length, embedding_dim),
                dropout=nn.Dropout(dropout_prob),
                block=nn.ModuleDict(
                    OrderedDict(
                        {
                            f"{i}": _Block(
                                sequence_length,
                                embedding_dim,
                                num_heads,
                                dropout_prob,
                                bias,
                            )
                            for i in range(depth)
                        }
                    )
                ),
                ln=nn.LayerNorm(embedding_dim, bias),
            )
        )
        self.lm_head = nn.Linear(embedding_dim, vocab_size, bias=False)

        # tying token_embedding's weight and lm_head's
        self.transformer.token_embedding.weight = self.lm_head.weight
        # apply init-weight
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith("proj.weight"):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * depth))

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding, nn.LayerNorm)):
            module.weight.data.normal_(mean=0.0, std=0.02)
        if isinstance(module, (nn.Linear, nn.LayerNorm)) and module.bias is not None:
            module.bias.data.zero_()

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        B, T = x.shape  # B:batch-size, T:sequence-length
        position = torch.arange(0, T, dtype=torch.long, device=x.device)

        x = self.transformer.token_embedding(x)
        position = self.transformer.position_embedding(position)
        x = x + position
        x = self.transformer.dropout(x)
        for block in self.transformer.block.values():
            x = block(x)
        x = self.transformer.ln(x)

        logits = self.lm_head(x)
        return logits

    def register_hooks(self) -> None:
        self.hook_input = list()
        # self.hook_q = list()
        # self.hook_k = list()
        # self.hook_qk_logit = list()
        # self.hook_qk_logit_masked = list()
        # self.hook_qk = list()
        # self.hook_output = list()

        def hook_fn(module, input, output) -> None:
            if isinstance(module, _SelfAttention):
                self.hook_input.append(input[0].detach().cpu())
                # self.hook_q.append(module.hook_q.detach().cpu())
                # self.hook_k.append(module.hook_k.detach().cpu())
                # self.hook_qk_logit.append(module.hook_qk_logit.detach().cpu())
                # self.hook_qk.append(module.hook_qk.detach().cpu())
                # self.hook_output.append(module.hook_output.detach().cpu())

        for block in self.transformer.block.values():
            block.attn.register_forward_hook(hook_fn)

    def clear_hooks(self) -> None:
        self.hook_input = list()
        # self.hook_q = list()
        # self.hook_k = list()
        # self.hook_qk_logit = list()
        # self.hook_qk = list()
        # self.hook_output = list()


# if __name__ == "__main__":
#     model = GPT2_ori(50304, 1024, 768, 12, 12)

#     from torchinfo import summary

#     summary(model)
